
fix: guard fp32 lm-head logits to contiguous to avoid vLLM NaN #2710

Open
mvanhorn wants to merge 1 commit into PrimeIntellect-ai:main from mvanhorn:fix/2497-fp32-lmhead-noncontiguous-logits-nan


@mvanhorn mvanhorn commented Jun 4, 2026

Summary

The fp32 lm-head path in src/prime_rl/inference/patches.py slices the padded vocab dimension with logits[..., : self.org_vocab_size], which returns a non-contiguous view when padded_vocab > org_vocab_size. Adding .contiguous() after the slice makes the physical row stride equal org_vocab_size, so vLLM's Triton top-k/top-p kernel reads the correct rows.
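The layout difference described above can be reproduced with a small tensor. This is a minimal sketch, not the code in patches.py: the standalone function `trim_logits_to_org_vocab` and the toy sizes (padded width 16, original width 10) are assumptions chosen for illustration.

```python
import torch


def trim_logits_to_org_vocab(logits: torch.Tensor, org_vocab_size: int) -> torch.Tensor:
    """Hypothetical helper: drop padded vocab columns and return a dense tensor."""
    # The slice alone is a view that keeps the padded row stride;
    # .contiguous() copies it so that stride(0) == org_vocab_size.
    return logits[..., :org_vocab_size].contiguous()


# Toy sizes (assumed): 4 rows, padded vocab 16, original vocab 10.
padded = torch.randn(4, 16, dtype=torch.float32)

view = padded[..., :10]                      # bare slice
trimmed = trim_logits_to_org_vocab(padded, 10)

print(view.is_contiguous(), view.stride())        # False (16, 1)
print(trimmed.is_contiguous(), trimmed.stride())  # True (10, 1)
print(torch.equal(view, trimmed))                 # True: same values, same dtype

# With no padding the slice is already contiguous, so .contiguous()
# returns the same tensor and no copy is made.
unpadded = trim_logits_to_org_vocab(padded, 16)
print(unpadded.data_ptr() == padded.data_ptr())   # True
```

The values and dtype are identical in both cases; only the row stride changes, which is the property a kernel that assumes dense rows depends on.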

Why this matters

Issue #2497 reported NaN log-probs on Olmo3 (org vocab 100278, padded 100288). vLLM's native Triton top-k/top-p kernel indexes rows as row_id * VOCAB_SIZE rather than by stride(0). Against the non-contiguous slice it therefore reads the wrong physical row and can mask a logical row to all -inf, after which processed_logprobs computes log_softmax(all -inf) = NaN. An upstream vLLM kernel fix was discussed, but the issue remains open, and the reporter asked prime-rl to guard at this boundary because other logits processors can also produce non-contiguous views. The merged PR #2506 (padded_input_scrub) covers a different padded-decode-input path, not this lm-head slice. This fix leaves the fp32 dtype and math unchanged and only normalizes memory layout.
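The indexing mismatch can be illustrated without a GPU. The following is a pure-Python sketch of the offset arithmetic only, not vLLM's kernel: the function names, the flat-list "buffer", and the toy sizes (3 rows, original width 6, padded width 8) are all assumptions for illustration.

```python
import math

NEG_INF = float("-inf")


def make_padded_buffer(num_rows, padded_vocab):
    """Flat row-major storage; each value encodes row * 1000 + col."""
    return [float(r * 1000 + c) for r in range(num_rows) for c in range(padded_vocab)]


def read_row_by_stride(buf, row, width, row_stride):
    """Read a logical row using the real row stride of the view."""
    start = row * row_stride
    return buf[start:start + width]


def read_row_assuming_contiguous(buf, row, width):
    """Read a row the way a kernel does if it computes row * VOCAB_SIZE."""
    start = row * width
    return buf[start:start + width]


def compact(buf, num_rows, width, row_stride):
    """Stand-in for .contiguous(): copy rows so the row stride equals width."""
    out = []
    for r in range(num_rows):
        out.extend(read_row_by_stride(buf, r, width, row_stride))
    return out


def log_softmax(row):
    """Numerically stabilised log-softmax over a list of floats."""
    m = max(row)
    shifted = [x - m for x in row]  # -inf - (-inf) evaluates to NaN
    log_z = math.log(sum(math.exp(s) for s in shifted))
    return [s - log_z for s in shifted]


ORG, PADDED, ROWS = 6, 8, 3  # toy sizes, assumed
buf = make_padded_buffer(ROWS, PADDED)

true_row1 = read_row_by_stride(buf, 1, ORG, PADDED)
wrong_row1 = read_row_assuming_contiguous(buf, 1, ORG)
print(true_row1)   # [1000.0, 1001.0, 1002.0, 1003.0, 1004.0, 1005.0]
print(wrong_row1)  # [6.0, 7.0, 1000.0, 1001.0, 1002.0, 1003.0]

# After compaction the dense-row assumption holds and the reads agree.
dense = compact(buf, ROWS, ORG, PADDED)
print(read_row_assuming_contiguous(dense, 1, ORG) == true_row1)  # True

# A row masked entirely to -inf has no finite maximum, so log-softmax is NaN.
print(all(math.isnan(v) for v in log_softmax([NEG_INF] * 4)))  # True
```

In the toy case the misread row starts inside row 0's padding columns. With Olmo3's sizes the offset error is 10 elements per row, so it grows with the row index.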

Testing

Covered by the new tests/unit/inference/test_fp32_lmhead_contiguous.py, which checks that: a sliced padded-vocab tensor is contiguous with stride (org_vocab_size, 1) after the patch; the no-padding case is a no-op; per-row argmax/top values are preserved; and a synthetic top-p selection over the guarded logits no longer yields a row of all -inf or a NaN log-softmax. Full suite runs in CI.

Fixes #2497


Note

Low Risk
Narrow layout fix at the lm-head boundary with unit tests; fp32 math and slice semantics unchanged aside from memory layout.

Overview
Fixes NaN log-probs when the fp32 lm-head path trims logits from a padded vocabulary to org_vocab_size.

The patch replaces a bare slice with _trim_logits_to_org_vocab, which keeps the same values but calls .contiguous() so each row’s stride matches the logical vocab width. vLLM’s Triton top-k/top-p path assumes contiguous rows; a non-contiguous view (common when padded_vocab > org_vocab) could read the wrong memory and mask a row to all -inf, yielding NaN log_softmax.

New unit tests in test_fp32_lmhead_contiguous.py lock in layout (stride/contiguity), unchanged argmax/top-k, and that guarded logits avoid the all--inf / NaN failure mode under a synthetic top-p kernel.

Reviewed by Cursor Bugbot for commit a138b11.


Development

Successfully merging this pull request may close these issues.

fp32 lm-head returns non-contiguous logits, triggers NaN in vLLM processed_logprobs